{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dynamic Programming\n",
    "\n",
    "We will be using dynamic programming on sample 4x4 grid world and study **Policy Evaluation** \n",
    "\n",
    "![GridWorld](./images/gridworld.png \"Grid World\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value Iteration \n",
    "\n",
    "Value Iteration is carried out by repeatedly applying equation (3.8) from text in a loop till change in state values fall below a threshold. Pseudo code for the algorithm is given in Fig 3-5 in the text."
   ]
  },
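  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the update described above can be written as the Bellman optimality backup in standard notation. We assume this is the form the referenced equation takes; the symbols $p$, $\\gamma$ and $\\theta$ are our notation and correspond to `env.P`, `discount_factor` and `theta` in the code below.\n",
    "\n",
    "$$V_{k+1}(s) = \\max_{a} \\sum_{s', r} p(s', r \\mid s, a)\\,\\bigl[r + \\gamma V_k(s')\\bigr]$$\n",
    "\n",
    "For transitions that end the episode, the bootstrapped term $\\gamma V_k(s')$ is dropped, since the value of a terminal state is zero. The loop stops once the largest change in a sweep is small:\n",
    "\n",
    "$$\\max_{s} \\bigl|V_{k+1}(s) - V_k(s)\\bigr| < \\theta$$"
   ]
  },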
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial imports and enviroment setup\n",
    "import numpy as np\n",
    "import sys\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sns.set()\n",
    "\n",
    "# create grid world invironment\n",
    "from gridworld import GridworldEnv\n",
    "env = GridworldEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom print to show state values inside the grid\n",
    "def grid_print(V, k=None):\n",
    "    ax = sns.heatmap(V.reshape(env.shape),\n",
    "                     annot=True, square=True,\n",
    "                     cbar=False, cmap='Blues',\n",
    "                     xticklabels=False, yticklabels=False)\n",
    "\n",
    "    if k:\n",
    "        ax.set(title=\"State values after K = {0}\".format(k))\n",
    "    else:\n",
    "        ax.set(title=\"State Values\".format(k))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value Iteration\n",
    "def value_iteration(env, discount_factor=1.0, theta=0.00001):\n",
    "    \"\"\"\n",
    "    Varry out Value iteration given an environment and a full description\n",
    "    of the environment's dynamics.\n",
    "\n",
    "    Args:\n",
    "        env: OpenAI env. env.P -> transition dynamics of the environment.\n",
    "            env.P[s][a] [(prob, next_state, reward, done)].\n",
    "            env.nS is number of states in the environment.\n",
    "            env.nA is number of actions in the environment.\n",
    "        discount_factor: Gamma discount factor.\n",
    "        theta: tolernace level to stop the iterations\n",
    "\n",
    "    Returns:\n",
    "        policy: [S, A] shaped matrix representing optimal policy.\n",
    "        value : [S] length vector representing optimal value\n",
    "    \"\"\"\n",
    "\n",
    "    def argmax_a(arr):\n",
    "        \"\"\"\n",
    "        Return idx of max element in an array.\n",
    "        \"\"\"\n",
    "        max_idx = []\n",
    "        max_val = float('-inf')\n",
    "        for idx, elem in enumerate(arr):\n",
    "            if elem == max_val:\n",
    "                max_idx.append(idx)\n",
    "            elif elem > max_val:\n",
    "                max_idx = [idx]\n",
    "                max_val = elem\n",
    "        return max_idx\n",
    "\n",
    "    optimal_policy = np.zeros([env.nS, env.nA])\n",
    "    V = np.zeros(env.nS)\n",
    "    V_new = np.copy(V)\n",
    "\n",
    "    while True:\n",
    "        delta = 0\n",
    "        # For each state, perform a \"greedy backup\"\n",
    "        for s in range(env.nS):\n",
    "            q = np.zeros(env.nA)\n",
    "            # Look at the possible next actions\n",
    "            for a in range(env.nA):\n",
    "                # For each action, look at the possible next states\n",
    "                # to calculate q[s,a]\n",
    "                for prob, next_state, reward, done in env.P[s][a]:\n",
    "                    # Calculate the value for each action as per backup diagram\n",
    "                    if not done:\n",
    "                        q[a] += prob*(reward + discount_factor * V[next_state])\n",
    "                    else:\n",
    "                        q[a] += prob * reward\n",
    "\n",
    "            # find the maximum value over all possible actions\n",
    "            # and store updated state value\n",
    "            V_new[s] = q.max()\n",
    "            # How much our value function changed (across any states)\n",
    "            delta = max(delta, np.abs(V_new[s] - V[s]))\n",
    "\n",
    "        V = np.copy(V_new)\n",
    "\n",
    "        # Stop if change is below a threshold\n",
    "        if delta < theta:\n",
    "            break\n",
    "\n",
    "    # V(s) has optimal values. Use these values and one step backup\n",
    "    # to calculate optimal policy\n",
    "    for s in range(env.nS):\n",
    "        q = np.zeros(env.nA)\n",
    "        # Look at the possible next actions\n",
    "        for a in range(env.nA):\n",
    "            # For each action, look at the possible next states\n",
    "            # and calculate q[s,a]\n",
    "            for prob, next_state, reward, done in env.P[s][a]:\n",
    "\n",
    "                # Calculate the value for each action as per backup diagram\n",
    "                if not done:\n",
    "                    q[a] += prob * (reward + discount_factor * V[next_state])\n",
    "                else:\n",
    "                    q[a] += prob * reward\n",
    "\n",
    "        # find the optimal actions\n",
    "        # We are returning stochastic policy which will assign equal\n",
    "        # probability to all those actions which are equal to maximum value\n",
    "        best_actions = argmax_a(q)\n",
    "        optimal_policy[s, best_actions] = 1.0 / len(best_actions)\n",
    "\n",
    "    return optimal_policy, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal Policy\n",
      " [[0.25 0.25 0.25 0.25]\n",
      " [0.   0.   0.   1.  ]\n",
      " [0.   0.   0.   1.  ]\n",
      " [0.   0.   0.5  0.5 ]\n",
      " [1.   0.   0.   0.  ]\n",
      " [0.5  0.   0.   0.5 ]\n",
      " [0.25 0.25 0.25 0.25]\n",
      " [0.   0.   1.   0.  ]\n",
      " [1.   0.   0.   0.  ]\n",
      " [0.25 0.25 0.25 0.25]\n",
      " [0.   0.5  0.5  0.  ]\n",
      " [0.   0.   1.   0.  ]\n",
      " [0.5  0.5  0.   0.  ]\n",
      " [0.   1.   0.   0.  ]\n",
      " [0.   1.   0.   0.  ]\n",
      " [0.25 0.25 0.25 0.25]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAD1CAYAAACr6uKwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAS40lEQVR4nO3de1RVdd7H8c95OAfUgeTIQBeFUEi8NFYI2aSOiKxJE5kANbFIFFGQhyxXrBxniZmulpY3FrFyYiIzRfCukKl5jWgSpXSVdFFHZTQ9MpIiAofb7/nD6awotQsbfl94Pq//OBw5n9i93Qdq7W1SSikQkTj/o3sAEd0c4yQSinESCcU4iYRinERCMU4ioRingY4ePYrY2FiMGTMG4eHhmDp1Kk6cOOH4/JQpU1BRUfGzX+eXPu97xcXFCA4ORm1tbbPH6+rqMGjQIJSWlt7yzwYEBPyq16K2wzgNUldXh+nTp2P27NnIz89HQUEBxowZg4SEBDQ2NgIAioqKftHX+qXP+97DDz+MO++8E7t37272+O7du9GzZ0/069fvV309koFxGqSmpgbXrl1DdXW147GIiAjMnTsXjY2N+Otf/woAmDRpEi5cuID9+/djwoQJiIqKQkhICFasWAEAP3mezWZDcnIyoqKiMGbMGKxcufKmrx8TE4NNmzY1eywvLw9PPfUUTp8+jcmTJ2P8+PEYPnw4kpKSYLfbmz138+bNmD59+k0/rqurwyuvvILIyEhERERg9uzZqKqqAgDk5OQgIiIC0dHRmDhxIk6ePNmC7yI1o8gw2dnZasCAASo0NFS98MILasOGDaq6utrx+d69e6vLly+rpqYm9fTTT6vTp08rpZS6ePGi6tu3r7p8+XKz5ymlVGxsrNq7d69SSqna2loVGxur3nvvvZ+89rVr11RgYKAqKytTSil1+vRpNXjwYGW329WiRYvU1q1blVJK1dXVqfDwcLVz585mr7Vp0yY1bdo0x9f74ccZGRlq0aJFqqmpSSml1NKlS9W8efNUQ0OD6t+/v7LZbEoppbZs2aJyc3ON+WaSMuv+y6EjmTx5MsaNG4fDhw/j8OHDyMrKQlZWFjZu3Ag3NzfH80wmE1auXIkDBw6goKAAp06dglIKNTU1zb5edXU1Dh8+jKtXryI9Pd3x2FdffYXHH3+82XNdXV0RERGBzZs3Y+bMmcjLy8PYsWPh7OyM1NRUFBUVISsrC2fOnMGlS5eaneF/zoEDB3Dt2jV8/PHHAID6+np4eHjAyckJI0eOxIQJExASEoIhQ4Zg2LBhv/XbRz/COA1SUlKCzz77DFOnTsXw4cMxfPhwzJo1C+Hh4SgqKsLIkSMdz62urkZkZCTCwsIQFBSE6Oho7NmzB+pH/5tzU1MTlFLIzc1F586dAQAVFRVwcXG56YaJEyciISEBiYmJyM/Px8aNGwEAs2bNQmNjI0aNGoWQkBBcuHDhJ69lMpmaPVZfX99sx5w5cxzhXb9+3fG2eMmSJfjmm2/w8ccf480338S2bdscf5FQy/BnToN069YNb7zxBo4cOeJ4rLy8HFVVVejduzcAwMnJCQ0NDTh79iyqqqrw3HPPITQ0FIcOHUJdXR2ampqaPc/V1RUPPvgg3n77bQBAZWUlYmJisHfv3ptuuO++++Dt7Y2lS5ciMDAQd911FwDgo48+QnJysuNse+zYMccvqX64/8SJE7Db7aivr8euXbscnxsyZAjWrl3r2Dh37lwsW7YMFRUVGDZsGNzd3REXF4fnnnsOn3/+uUHfUeKZ0yA9e/ZEZmYmli9fjosXL8LFxQVubm545ZVX0KtXLwDAyJEjERsbi/T0dISEhGDUqFFwdnZG79694e/vj7Nnz8LHx8fxvIyMDCxZsgQLFizAmDFjUFdXh/DwcERERNxyx8SJE/H8889j1apVjseef/55JCcno0uXLnB1dUVwcDDKysqa/bnBgwcjODgYo0aNgqen
JwYNGoSvv/4aADBjxgwsXrwYkZGRaGxsRN++fTF79my4uroiKSkJcXFx6NSpE5ycnLBw4ULjv7n/T5nUj9/fEJEIfFtLJBTjJBKKcRIJxTiJhLrtb2s7P/S/bbWjTT3w5DjdE1rFC6N6657QasL73617QqvpdIsKeeYkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioRgnkVAibmQ0ckh/vJwSARdnM744cR6J83Nw7Xqt7lmGmjs6AKfKryOn+JzuKYY4+uFuFObnAQCcXTph9OQU9PDro3mVMdatXYP1eetgMpng7e2NtPkL4eHh0eY7tJ85f291xd/nP42Y1H/ggcgFOH3uMhY8e+u7aLU3vh5d8HrMAIQGeOqeYpjyb8vw/pqVmDTnVaS89hZComKRsyRN9yxDlB7/AqtXZWP12lxs3lYAn3t9kZmh536j2uMMe6QPSo6fxamycgDAmxsKMWFUsOZVxokOvAfbj13Evq/LdU8xjNlsQWRiKu6w3jibdPcLQNWVCjQ01P/Mn5SvX//7sX3HLri5ucFut+OSzQZ3d3ctW7TH2eMuK87Zrjg+Pn/pCrq6dYbb7zrpG2WgpR+cxO7SS7pnGMrqdTf6BP4RAKCUwo53MtEn6FGYzRbNy4xhsViwb+8e/Dn0TygpOYy/REZp2aE9zh/f7vx7jY1NGtbQr1FXW4Pc5S+hwnYekYmpuucYKnREGA4WHULSjBQkTYt33HW8LWn/hdC/L36H4D/4Oj7u7tUVFVevo7q2Tt+oFkgY6ouh/jfe7hWevIyswjN6BxlkT142vjxSBADoGzQYQSNG493Fc+DZ3Qfx81bA4uyieeFvl5mRjoP79wEAevn5Y/yEGAQODAIAPBEVjYUvz0Nl5VW4u1vbdJf2OPf+80ssmhUJPx9PnCorx9SxQ1Fw4HPds36zrMIzHSbIHwp7cgrCnpwCALDXVCMjNR4PDXsMI8bF6R1mgOSUmUhOmQkA+LTkCF5MnYX1m7bCau2GHQX58Pe/r83DBATEWf5dFaa/tAY5r8XD2WzGv879B1PnrtY9i27jk51bcKXchtLiQpQWFzoej09bhi5uXTUua7nAgUFImJaI+LhnYHZygqeXF5ZnZGrZYlI3+4Hvv3gLwPaFtwBsn3gLQKJ2hnESCcU4iYRinERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRULe9bi0vIdm+dOTLRxYcv6B7QqsZ+8DNjxvPnERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCXXb69a2tbmjA3Cq/Dpyis/pnmKYox/uRmF+HgDA2aUTRk9OQQ+/PppXtdy6tWuwPm8dTCYTvL29kTZ/ITw8PHTPMoSUYybizOnr0QWvxwxAaICn7imGKv+2DO+vWYlJc15FymtvISQqFjlL0nTParHS419g9apsrF6bi83bCuBzry8yM9J1zzKEpGMm4swZHXgPth+7CFulXfcUQ5nNFkQmpuIO640zSne/AFRdqUBDQz3MZovmdb9dv/73Y/uOXbBYLLDb7bhks6F7jx66ZxlC0jETEefSD04CAAb1tGpeYiyr192wet241L5SCjveyUSfoEfbdZjfs1gs2Ld3D+an/Q0WZ2fMSHlW9yRDSDpmIt7WdnR1tTXIXf4SKmznEZmYqnuOYUJHhOFg0SEkzUhB0rR4NDU16Z5kGAnHTMuZM2GoL4b633jbUHjyMrIKz+iY0Sr25GXjyyNFAIC+QYMRNGI03l08B57dfRA/bwUszi6aF/42mRnpOLh/HwCgl58/
xk+IQeDAIADAE1HRWPjyPFRWXoW7e/t79yP1mJmUUupWn3xk0cG23NJmv61tq7uM2WuqkZEaj4eGPYYR4+Ja/fXa6i5jn5YcwYups7B+01ZYrd2Qv20rVq/KxoYt21vtNdvqLmNtfcyAW99lTMTPnB3VJzu34Eq5DaXFhSgtLnQ8Hp+2DF3cumpc1jKBA4OQMC0R8XHPwOzkBE8vLyzPyNQ9yxCSjpmoM2db4f052x/en5OIxGCcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJNRtr1vLS0i2Lx358pFL3v9G94RWw0tjErUzjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgnFOImEuu2lMdvK0Q93ozA/DwDg7NIJoyenoIdfH82rjLFu7Rqsz1sHk8kEb29vpM1fCA8PD92zWqwjH7PvzR0dgFPl15FTfE7L62s/c5Z/W4b316zEpDmvIuW1txASFYucJWm6Zxmi9PgXWL0qG6vX5mLztgL43OuLzIx03bNarCMfMwDw9eiC12MGIDTAU+sO7WdOs9mCyMRU3GG9cTbp7heAqisVaGioh9ls0byuZfr1vx/bd+yCxWKB3W7HJZsN3Xv00D2rxTryMQOA6MB7sP3YRdgq7Vp3aI/T6nU3rF43rnitlMKOdzLRJ+jRDnGQAcBisWDf3j2Yn/Y3WJydMSPlWd2TWqyjH7OlH5wEAAzqadW6Q/vb2u/V1dYgd/lLqLCdR2Riqu45hgodEYaDRYeQNCMFSdPi0dTUpHuSITryMZNAy5lzT142vjxSBADoGzQYQSNG493Fc+DZ3Qfx81bA4uyiY5YhMjPScXD/PgBALz9/jJ8Qg8CBQQCAJ6KisfDleaisvAp3d71/K/9aHfmYJQz1xVD/G2/RC09eRlbhGb2D/suklFK3+uTGY61/Yxx7TTUyUuPx0LDHMGJcXKu/HtB2NzL6tOQIXkydhfWbtsJq7Yb8bVuxelU2NmzZ3iqv11Y3MtJxzHTcyKitflv7yexhN31c+8+cn+zcgivlNpQWF6K0uNDxeHzaMnRx66pxWcsFDgxCwrRExMc9A7OTEzy9vLA8I1P3rBbryMdMEu1nTh14C8D2pyPfAvBWZ04xvxAiouYYJ5FQjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgl120tj1ja05ZS201EvIdmRLx95LG+D7gmtpuaz12/6OM+cREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioRgnkVCMk0goxkkkFOMkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJBTjJBLKrHsAAKxbuwbr89bBZDLB29sbafMXwsPDQ/csQxz9cDcK8/MAAM4unTB6cgp6+PXRvMo4c0cH4FT5deQUn9M9xRAjh/THyykRcHE244sT55E4PwfXrtdq2aL9zFl6/AusXpWN1WtzsXlbAXzu9UVmRrruWYYo/7YM769ZiUlzXkXKa28hJCoWOUvSdM8yhK9HF7weMwChAZ66pxjm91ZX/H3+04hJ/QceiFyA0+cuY8GzEdr2aI+zX//7sX3HLri5ucFut+OSzQZ3d3fdswxhNlsQmZiKO6w33gV09wtA1ZUKNDTUa17WctGB92D7sYvY93W57imGCXukD0qOn8Wpshv/TG9uKMSEUcHa9miPEwAsFgv27d2DP4f+CSUlh/GXyCjdkwxh9bobfQL/CABQSmHHO5noE/QozGaL5mUtt/SDk9hdekn3DEP1uMuKc7Yrjo/PX7qCrm6d4fa7Tlr2
iIgTAEJHhOFg0SEkzUhB0rR4NDU16Z5kmLraGuQufwkVtvOITEzVPYduwWQy4WZ3J2ls1PPvopZfCGVmpOPg/n0AgF5+/hg/IQaBA4MAAE9ERWPhy/NQWXkV7u5WHfNaZE9eNr48UgQA6Bs0GEEjRuPdxXPg2d0H8fNWwOLsonnhb5Mw1BdD/W+8PS88eRlZhWf0DmoF/774HYL/4Ov4uLtXV1RcvY7q2jote7TEmZwyE8kpMwEAn5YcwYups7B+01ZYrd2woyAf/v73tcswASDsySkIe3IKAMBeU42M1Hg8NOwxjBgXp3dYC2UVnumQQf7Q3n9+iUWzIuHn44lTZeWYOnYoCg58rm2P9v+UEjgwCAnTEhEf9wzMTk7w9PLC8oxM3bMM8cnOLbhSbkNpcSFKiwsdj8enLUMXt64al9HNlH9XhekvrUHOa/FwNpvxr3P/wdS5q7Xt4S0AOxDeArB94i0AidoZxkkkFOMkEopxEgnFOImEYpxEQjFOIqEYJ5FQjJNIKMZJJBTjJBKKcRIJxTiJhGKcREIxTiKhGCeRUIyTSCjGSSQU4yQSinESCcU4iYRinERCMU4ioW573Voi0odnTiKhGCeRUIyTSCjGSSQU4yQSinESCfV/m/Md7B0I3bcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run policy iteration on Grid world\n",
    "pi_star, V_star = value_iteration(env)\n",
    "\n",
    "# Print Optimal policy\n",
    "print(\"Optimal Policy\\n\", pi_star)\n",
    "\n",
    "# Print optimal state values\n",
    "grid_print(V_star.reshape(env.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conclusion\n",
    "\n",
    "Similar to Policy Iteration in `listing3_3`, we see that optimal state values are negative of number of steps required to each the closest terminal state. As the reward is -1 for each time step till agent reaches terminal state, the optimal policy would take the agent to terminal state in minimal number of possible steps. For some states, more than one action could lead to same number of steps to reach terminal state. For example, look at top right state with value=-3, it takes 3 steps to reach the terminal state at top-left or terminal state at bottom-right. In other words, the state values is negative of Manhattan distance between the state and nearest terminal state."
   ]
  }
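  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Manhattan-distance observation can be sketched in closed form. Assume 0-indexed row $r$ and column $c$ (our notation), terminal states at the top-left $(0, 0)$ and bottom-right $(3, 3)$ corners, deterministic moves, a reward of -1 per step, and no discounting. Under those assumptions:\n",
    "\n",
    "$$V^*(r, c) = -\\min\\bigl(r + c,\\; (3 - r) + (3 - c)\\bigr)$$\n",
    "\n",
    "For the top-right state $(0, 3)$ this gives $-\\min(3, 3) = -3$, which matches the value discussed above."
   ]
  }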
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
